import numpy as np
import multiprocessing as mp
import pandas as pd
import time

# Sampling script: Simulation_MultiHeadAttention_sqrtn.py
# This script performs simulations and writes results to CSV for plotting.

# Simulation parameters will be set via params dict
# n: Network width
# s: Spatial dimension
# H: Number of heads
# num_runs: Number of simulation runs
# num_processes: Number of parallel processes
# seed: Random seed
# mc_runs: Monte Carlo runs for theoretical density
# C: Clipping threshold for psi

# Module-level simulation settings. Every one starts unset (None) and is
# filled in by init_globals(), which is called in the main process and used
# as the initializer of each pool worker.
n = s = H = None                 # network width, spatial dimension, heads
num_runs = mc_runs = None        # empirical runs, Monte Carlo runs
num_processes = seed = C = None  # pool size, random seed, clipping threshold

# Weight sampling W ~ N(0, 1/n)
def scaled_weights(shape):
    """Draw a Gaussian array of the given shape with entry variance 1/n.

    ``shape`` is an iterable of dimensions; ``n`` is the module-level width
    set by ``init_globals``. Samples come from the global NumPy RNG.
    """
    std_inverse = np.sqrt(n)
    gaussian = np.random.randn(*shape)
    return gaussian / std_inverse

# Initialize globals in each process or main
def init_globals(params):
    """Copy simulation settings from ``params`` into the module globals.

    For every known setting name, the value in ``params`` replaces the
    current module-level value; names missing from ``params`` keep whatever
    value they already had. Afterwards the global NumPy RNG is reseeded with
    the resulting ``seed``.
    """
    module_state = globals()
    setting_names = (
        'n', 's', 'H', 'num_runs', 'num_processes', 'seed', 'mc_runs', 'C',
    )
    for key in setting_names:
        module_state[key] = params.get(key, module_state[key])
    np.random.seed(module_state['seed'])

# Softmax function
def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating. This
    leaves the mathematical result unchanged but keeps ``np.exp`` from
    overflowing to ``inf`` (and the quotient from becoming ``nan``) when the
    logits are large.

    Parameters:
        x: array-like of logits.
        axis: axis along which the outputs sum to one (default: last axis).

    Returns:
        Array with the same shape as ``x``.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalize; np.max would fail on a zero-length axis.
        return np.exp(x)
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

# Single empirical run: output matrix Y shape (s,n)
def single_run(_):
    """Simulate one random multi-head attention output of shape (s, n).

    The argument is ignored; it exists so the function can be mapped over a
    range by a process pool. All sizes (n, s, H) and the clipping threshold
    C are read from the module globals, and all randomness comes from the
    global NumPy RNG.
    """
    # Input: a standard normal vector clipped to [-C, C].
    raw_input = np.random.randn(n)
    clipped = np.clip(raw_input, -C, C)

    # One (n, n) random map per spatial position gives the (s, n) tokens.
    position_maps = scaled_weights((s, n, n))
    tokens = position_maps @ clipped  # (s,n)

    total = np.zeros((s, n))
    for _head in range(H):
        # Fresh weights per head, drawn in the order query, key, value, output.
        w_query = scaled_weights((n, n))
        w_key = scaled_weights((n, n))
        w_value = scaled_weights((n, n))
        w_out = scaled_weights((n, n))

        queries = tokens @ w_query
        keys = tokens @ w_key
        values = tokens @ w_value
        projected = values @ w_out

        scores = queries.dot(keys.T) / np.sqrt(n)
        attention = softmax(scores, axis=1)
        total += attention @ projected
    return total

# Run one empirical sample with its own RNG stream (executed in pool workers)
def _seeded_single_run(run_index):
    """Reseed the global NumPy RNG for ``run_index`` and call ``single_run``.

    Every pool worker is initialized by ``init_globals`` with the same seed,
    so without reseeding all workers would start from the same RNG state and
    return duplicated samples. Seeding with the key ``[seed, run_index]``
    gives each run a distinct stream and makes the output independent of how
    tasks are scheduled across workers. When ``seed`` is None the worker's
    entropy-based seeding is left untouched.
    """
    if seed is not None:
        # np.random.seed accepts a 1-d integer array as the seeding key.
        np.random.seed(np.append(np.atleast_1d(seed), run_index))
    return single_run(run_index)


# Simulate empirical: returns array (num_runs, s, n)
def simulate_empirical(params):
    """Run ``params['num_runs']`` empirical simulations in a process pool.

    Parameters:
        params: dict of simulation settings; must contain 'num_processes'
            and 'num_runs'. It is installed into the module globals of the
            main process and of every worker via ``init_globals``.

    Returns:
        Array of shape (num_runs, s, n) holding the stacked ``single_run``
        outputs in run order.
    """
    init_globals(params)
    with mp.Pool(processes=params['num_processes'], initializer=init_globals, initargs=(params,)) as pool:
        out = pool.map(_seeded_single_run, range(params['num_runs']))
    return np.stack(out)

# Theoretical Monte Carlo for n^{-1/2}: returns array (mc_runs, s)
def simulate_theoretical(params):
    """Monte Carlo sample of the theoretical output, shape (mc_runs, s).

    After installing ``params`` via ``init_globals``, each run draws, for
    every head, an (s, s) standard normal logit matrix and a standard normal
    vector of length s. The row-wise softmax of the logits is applied to the
    vector, and the results are summed over the H heads.
    """
    init_globals(params)
    total_runs = params['mc_runs']

    # Draw order (logits first, then vectors) fixes the random stream.
    logit_draws = np.random.randn(total_runs, H, s, s)
    vector_draws = np.random.randn(total_runs, H, s)

    y = np.zeros((total_runs, s))
    for run, (run_logits, run_vectors) in enumerate(zip(logit_draws, vector_draws)):
        for head_logits, head_vector in zip(run_logits, run_vectors):
            weights = softmax(head_logits, axis=1)
            y[run] += weights.dot(head_vector)
    return y

if __name__ == '__main__':
    # Experiment grid: the width sweep covers 4**exp_min .. 4**exp_max.
    exp_min, exp_max = 2, 5
    # Alternative default width: 2 ** ((exp_min + exp_max) // 2)
    default_n = 256
    default_H = 2
    base = {
        'n': default_n,
        's': 4,
        'H': default_H,
        'num_runs': 300000,
        'num_processes': 18,
        'seed': 0,
        'mc_runs': 300000,
        'C': 100,
    }

    # Values swept in the two experiments.
    n_vals = [4 ** e for e in range(exp_min, exp_max + 1)]
    # Alternative head sweep: [4 ** i for i in range(3)]
    H_vals = [1, 256]

    records_n = []
    records_H = []

    start = time.time()

    # Width sweep: one theoretical sample (base settings) is paired with the
    # empirical sample of every width.
    theo = simulate_theoretical(base)
    y_theo = theo[:, 0]
    for n_val in n_vals:
        emp = simulate_empirical({**base, 'n': n_val})
        records_n.extend(
            {'param': n_val, 'y_emp': ye, 'y_theo': yt}
            for ye, yt in zip(emp[:, 0, 0], y_theo)
        )

    # Head sweep (default width): the theoretical sample is redrawn for each
    # head count, after the empirical one.
    for H_val in H_vals:
        head_params = {**base, 'H': H_val}
        emp = simulate_empirical(head_params)
        theo_h = simulate_theoretical(head_params)
        records_H.extend(
            {'param': H_val, 'y_emp': ye, 'y_theo': yt}
            for ye, yt in zip(emp[:, 0, 0], theo_h[:, 0])
        )

    # Write one CSV per sweep for the plotting step.
    for rows, csv_name in ((records_n, 'data_vary_n.csv'), (records_H, 'data_vary_H.csv')):
        pd.DataFrame(rows).to_csv(csv_name, index=False)

    print(f"Sampling done in {time.time()-start:.2f} seconds")
